#!/usr/bin/env python3
import math
import sys
import numpy as np

import argparse
import torch
import cv2
import pyzed.sl as sl
import torch.backends.cudnn as cudnn

sys.path.insert(0, './yolov5')
from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh
from utils.torch_utils import select_device
from utils.augmentations import letterbox
from utils.plots import Annotator
from threading import Lock, Thread
from time import sleep

# Hand-off state shared between main() (grabs camera frames) and
# torch_thread() (runs inference on them).
lock = Lock()  # held while the shared frame / point cloud is written or read
run_signal = False  # set by main() when a new frame is ready; cleared by torch_thread() once handled
exit_signal = False  # when set, both the camera loop and the inference loop stop


def img_preprocess(img, device, half, net_size):
    """Letterbox a frame and turn it into a normalized network input tensor.

    Args:
        img: HxWxC numpy image; only the first three channels are used.
            Presumably BGR(A) as delivered by the camera, given the channel
            flip below — confirm.
        device: torch device the tensor is moved to.
        half: cast to float16 when True, otherwise float32.
        net_size: target size handed to ``letterbox``.

    Returns:
        ``(tensor, ratio, pad, original)``. ``tensor`` is 1x3xHxW and divided
        by 255 (so uint8 input maps to [0, 1]). ``ratio`` and ``pad`` are
        ``letterbox``'s scaling info. ``original`` is the untouched input
        image.
    """
    original = img
    boxed, ratio, pad = letterbox(img[:, :, :3], net_size, auto=False)

    # HWC -> CHW, then reverse the channel axis (BGR -> RGB).
    chw = np.ascontiguousarray(boxed.transpose((2, 0, 1))[::-1])

    tensor = torch.from_numpy(chw).to(device)
    tensor = tensor.half() if half else tensor.float()
    tensor /= 255.0

    # Add the batch dimension expected by the model.
    if tensor.ndimension() == 3:
        tensor = tensor.unsqueeze(0)
    return tensor, ratio, pad, original


def xywh2abcd(xywh, im_shape):
    """Convert a normalized center/size box into its four pixel-space corners.

    Args:
        xywh: sequence ``(x_center, y_center, width, height)``, each value
            normalized by the image width/height.
        im_shape: image shape where ``im_shape[0]`` is the height and
            ``im_shape[1]`` is the width in pixels (numpy HxW[xC] order).

    Returns:
        A ``(4, 2)`` float array of ``(x, y)`` corners in the order A, B, C, D,
        i.e. clockwise starting at the top-left::

            A ------ B
            | Object |
            D ------ C
    """
    x_c, y_c, w, h = map(float, xywh[:4])
    img_h, img_w = im_shape[0], im_shape[1]

    # Center / Width / Height -> BBox corners coordinates
    x_min = (x_c - 0.5 * w) * img_w
    x_max = (x_c + 0.5 * w) * img_w
    y_min = (y_c - 0.5 * h) * img_h
    y_max = (y_c + 0.5 * h) * img_h

    # Clockwise order matching the diagram above; the ZED SDK documents
    # bounding_box_2d corners the same way (confirm for your SDK version).
    return np.array(
        [
            [x_min, y_min],  # A: top-left
            [x_max, y_min],  # B: top-right
            [x_max, y_max],  # C: bottom-right
            [x_min, y_max],  # D: bottom-left
        ],
        dtype=float,
    )


def _load_model(weights, img_size):
    """Load YOLOv5 weights on the selected device and run one dummy pass.

    Returns ``(model, device, half, imgsz)``. ``half`` tells whether the model
    was cast to FP16. ``imgsz`` is ``img_size`` adjusted by
    ``check_img_size`` for the model stride.
    """
    device = select_device()
    half = device.type != 'cpu'  # half precision only supported on CUDA

    model = attempt_load(weights, map_location=device)  # load FP32
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(img_size, s=stride)  # check img_size
    if half:
        model.half()  # to FP16
    cudnn.benchmark = True

    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    return model, device, half, imgsz


def _label_with_position(annotator, xyxy, cloud):
    """Label one box with the point-cloud position sampled at its center.

    The sampled X, Y, Z are multiplied by -1000 and rounded. main() opens
    the camera with UNIT.METER, so this presumably yields millimetres with
    the sign flipped — confirm. The values are printed and drawn together
    with their Euclidean norm. Nothing is drawn when the sample cannot be
    read, is not finite, or the scaled Z is not positive.
    """
    x1, y1, x2, y2 = (float(v) for v in xyxy)
    cent_x = round((x1 + x2) / 2)
    cent_y = round((y1 + y2) / 2)

    err, value = cloud.get_value(cent_x, cent_y)
    if err != sl.ERROR_CODE.SUCCESS:
        return

    # Only X, Y, Z are used (the 4th XYZRGBA channel presumably holds color).
    xyz = np.asarray(value[:3], dtype=np.float64) * -1000.0
    # Depth holes come back as NaN/inf; they used to be filtered by a bare
    # `except: pass` around round(), which also hid every other error.
    if not np.all(np.isfinite(xyz)) or xyz[2] <= 0.0:
        return

    x, y, z = (float(c) for c in np.round(xyz))
    distance = math.sqrt(x * x + y * y + z * z)

    print("x:", x, "y:", y, "z:", z, "dis:", distance)

    txt = 'x:{0} y:{1} z:{2} dis:{3}'.format(x, y, z, distance)
    annotator.box_label(xyxy, txt, color=(255, 0, 255))


def _detect_and_annotate(model, device, half, imgsz, frame, cloud, conf_thres, iou_thres):
    """Run detection + NMS on ``frame`` and return an annotated copy of it."""
    img, _ratio, _pad, _ = img_preprocess(frame, device, half, imgsz)
    pred = model(img, augment=False, visualize=False)[0]
    pred = non_max_suppression(pred, conf_thres, iou_thres)

    # Draw on a copy so the shared camera frame buffer is not modified.
    im0 = frame.copy()
    annotator = Annotator(im0, line_width=2, example=str('A'))
    for det in pred:
        if not len(det):
            continue
        # Map boxes from network-input coordinates back to the frame.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
        for *xyxy, _conf, _cls in reversed(det):
            _label_with_position(annotator, xyxy, cloud)
    return annotator.result()


def torch_thread(weights, img_size, conf_thres=0.6, iou_thres=0.6):
    """Inference loop run on a worker thread, fed with frames by main().

    main() writes the latest frame into the global ``image_net`` and the
    point cloud into ``point_cloud`` while holding ``lock``. It then sets
    ``run_signal`` and waits until this thread clears it. Each frame handed
    over this way is run through YOLOv5, labeled with the 3-D position at
    each detection's center, and shown in an OpenCV window. Pressing 'q' in
    that window sets ``exit_signal`` so both loops stop. If this loop raises,
    ``exit_signal`` is set and ``run_signal`` is cleared before the exception
    propagates, so main() does not wait forever.

    Args:
        weights: model weight path(s), forwarded to ``attempt_load``.
        img_size: requested network input size in pixels.
        conf_thres: NMS confidence threshold.
        iou_thres: NMS IoU threshold.
    """
    global image_net, exit_signal, run_signal, detections, point_cloud

    print("Intializing Network...")

    # Autograd mode is thread-local: the torch.no_grad() entered in the
    # __main__ block does not apply to this thread, so disable it here.
    with torch.no_grad():
        model, device, half, imgsz = _load_model(weights, img_size)

        while not exit_signal:
            if run_signal:
                try:
                    with lock:
                        annotated = _detect_and_annotate(
                            model, device, half, imgsz, image_net, point_cloud,
                            conf_thres, iou_thres)
                    # Show every frame, not only frames that had detections.
                    cv2.imshow('00', annotated)
                    # waitKey returns an int key code (-1 if none), never a str.
                    if (cv2.waitKey(10) & 0xFF) == ord('q'):
                        exit_signal = True
                except Exception:
                    exit_signal = True
                    raise
                finally:
                    run_signal = False
            sleep(0.01)


def main():
    """Open the ZED camera and feed frames to the inference thread.

    Starts ``torch_thread`` first, so the model loads while the camera
    initializes. Then it opens the camera: live, or the SVO file given by
    ``--svo``. It enables positional tracking and custom-box object
    detection. For every grabbed frame it publishes the left image
    (``image_net``) and the XYZRGBA point cloud (``point_cloud``) under
    ``lock``, sets ``run_signal``, and waits for the worker to clear it.

    The loop stops when a grab fails, when the worker sets ``exit_signal``,
    or when the worker thread dies. The worker is always stopped and joined
    before the camera is closed. If the camera cannot be opened, the status
    is printed and the process exits with status 1.
    """
    global image_net, exit_signal, run_signal, detections, point_cloud

    inference_thread = Thread(target=torch_thread,
                              kwargs={'weights': opt.weights, 'img_size': opt.img_size,
                                      "conf_thres": opt.conf_thres})
    inference_thread.start()

    print("Initializing Camera...")
    zed = sl.Camera()
    input_type = sl.InputType()
    if opt.svo is not None:
        input_type.set_from_svo_file(opt.svo)

    # Create an InitParameters object and set configuration parameters
    init_params = sl.InitParameters(input_t=input_type, svo_real_time_mode=True)
    init_params.camera_resolution = sl.RESOLUTION.HD1080
    init_params.coordinate_units = sl.UNIT.METER
    init_params.depth_mode = sl.DEPTH_MODE.ULTRA  # QUALITY
    init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP
    init_params.depth_maximum_distance = 10

    runtime_params = sl.RuntimeParameters()
    status = zed.open(init_params)

    if status != sl.ERROR_CODE.SUCCESS:
        print(repr(status))
        # The worker is a non-daemon thread looping on exit_signal; without
        # this the interpreter would wait for it forever on exit.
        exit_signal = True
        sys.exit(1)

    image_left_tmp = sl.Mat()

    print("Initialized Camera")

    positional_tracking_parameters = sl.PositionalTrackingParameters()
    zed.enable_positional_tracking(positional_tracking_parameters)
    obj_param = sl.ObjectDetectionParameters()
    obj_param.detection_model = sl.DETECTION_MODEL.CUSTOM_BOX_OBJECTS
    obj_param.enable_tracking = True
    zed.enable_object_detection(obj_param)

    objects = sl.Objects()
    obj_runtime_param = sl.ObjectDetectionRuntimeParameters()

    point_cloud = sl.Mat()

    try:
        while not exit_signal:
            if zed.grab(runtime_params) != sl.ERROR_CODE.SUCCESS:
                break

            # -- Publish the frame for the inference thread
            with lock:
                zed.retrieve_image(image_left_tmp, sl.VIEW.LEFT)
                image_net = image_left_tmp.get_data()
                zed.retrieve_measure(point_cloud, sl.MEASURE.XYZRGBA)
            run_signal = True

            # -- Detection runs on the other thread; it clears run_signal.
            # Also stop waiting if that thread has died.
            while run_signal and inference_thread.is_alive():
                sleep(0.001)
            if not inference_thread.is_alive():
                break

            # NOTE(review): no detections are ever passed to the SDK
            # (ingest_custom_box_objects is never called), so this presumably
            # returns no objects — confirm intent.
            zed.retrieve_objects(objects, obj_runtime_param)
    finally:
        exit_signal = True
        # Let the worker finish with point_cloud before the camera is closed.
        inference_thread.join()
        zed.close()


if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    add_opt = cli.add_argument
    add_opt('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    add_opt('--svo', type=str, default=None, help='optional svo file')
    add_opt('--img_size', type=int, default=416, help='inference size (pixels)')
    add_opt('--conf_thres', type=float, default=0.6, help='object confidence threshold')

    # main() reads the parsed options through the module-level name `opt`.
    opt = cli.parse_args()

    # NOTE: autograd mode is thread-local, so this only covers the main thread.
    with torch.no_grad():
        main()
